# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
import random
from typing import Union, List

from atk.case_generator.generator.generate_types import GENERATOR_REGISTRY
from atk.case_generator.generator.base_generator import CaseGenerator
from atk.configs.case_config import InputCaseConfig, CaseConfig


@GENERATOR_REGISTRY.register("ascend_generate_moe_gating_topK_softmax_v2")
class GroupNormGenerator(CaseGenerator):
    """Case generator registered as ``ascend_generate_moe_gating_topK_softmax_v2``.

    After the base case config is built, the shapes/ranges of inputs 1 and 2
    are derived from the shape of input 0 (see ``after_case_config``).

    NOTE(review): the class name looks copied from a GroupNorm generator; it
    is kept unchanged because other modules may import it by this name.
    """

    def __init__(self, config):
        super().__init__(config)
        # Neither attribute is read elsewhere in this class; presumably they
        # are consumed by ``CaseGenerator`` or a caller -- TODO confirm.
        self.tensor_dim, self.range_is_null = 0, False

    def after_case_config(self, case_config: CaseConfig) -> CaseConfig:
        """Derive inputs 1 and 2 from the shape of input 0.

        Args:
            case_config: The generated case; modified in place.

        Returns:
            The same ``case_config`` object after modification.
        """
        inputs = case_config.inputs
        src_shape = inputs[0].shape

        # Input 1 takes input 0's shape minus its last dimension
        # (presumably the optional ``finished`` mask -- TODO confirm).
        inputs[1].shape = src_shape[:-1]

        # Input 2's range is tied to the size of input 0's last dimension
        # (presumably ``k``, bounded by the expert count -- TODO confirm).
        inputs[2].range_values = src_shape[-1]

        return case_config